
[LTX-2] Run Gemma-3 Text Encoder natively in JAX via TorchAX #398

Open

mbohlool wants to merge 1 commit into main from text_encoder_tpu3

Conversation


@mbohlool mbohlool commented May 4, 2026

Description

This PR transitions the LTX-2 pipeline's text encoding process to utilize TorchAX, bridging the Gemma-3 model natively into JAX and significantly optimizing memory usage to prevent TPU out-of-memory errors. Minor PyLint warnings across the pipeline were also resolved during the refactor.

Key changes include:

  • TorchAX Integration: Replaced the eager PyTorch-based text encoder execution with the JAX-native TorchaxGemma3TextEncoder. TPU sharding is now manually distributed across the batch dimension via jax.device_put to prevent Softmax OOM crashes (see the sketch after this list).
  • VAE Memory Optimization: Updated the VAE decoding loop to conditionally apply sharding constraints. By disabling sequential slicing and dynamically adjusting batch sharding for batch_size > 2, HBM crashes during decoding are avoided.
  • Lint Cleanup: Addressed minor PyLint warnings in the pipeline and encoder wrapper to maintain code health.
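
A minimal sketch of the batch-dimension sharding idea, using made-up shapes and a hypothetical 1-D "data" mesh axis (the actual pipeline's mesh layout and variable names may differ):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D mesh over all available devices; the real pipeline may use a different layout.
mesh = Mesh(jax.devices(), axis_names=("data",))

# Placeholder tokenized batch: (batch, seq_len) ids and attention mask.
batch, seq_len = 8, 256  # batch should be divisible by the number of devices
text_input_ids = jnp.zeros((batch, seq_len), dtype=jnp.int32)
prompt_attention_mask = jnp.ones((batch, seq_len), dtype=jnp.int32)

# Shard the batch dimension across devices so each chip only materializes its
# slice of the attention/softmax activations, avoiding a single-device OOM.
batch_sharding = NamedSharding(mesh, P("data", None))
text_input_ids = jax.device_put(text_input_ids, batch_sharding)
prompt_attention_mask = jax.device_put(prompt_attention_mask, batch_sharding)
```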

Benchmarks

Performance comparison demonstrating latency improvements from TorchAX integration.

| Configuration | Text Encoding (CPU) | Text Encoding (TorchAX) | Text Encoding Impr. | Total Time (TE on CPU) | Total Time (TE on TorchAX) | Generation Impr. |
|---|---|---|---|---|---|---|
| Batch Size 1 (Latency Optimized) | 3.75s | 2.52s | 32.93% | 13.19s | 11.67s | 11.47% |
| Batch Size 1 (w/ Upsampler) | 3.57s | 2.47s | 30.72% | 16.65s | 15.61s | 6.28% |
| Batch Size 8 (Throughput Optimized) | 23.23s | 5.86s | 74.77% | 80.14s | 60.40s | 24.64% |
| Batch Size 8 (w/ Upsampler) | 23.36s | 6.10s | 73.87% | 114.98s | 86.74s | 24.56% |

@mbohlool mbohlool requested a review from entrpn as a code owner May 4, 2026 20:08

@mbohlool mbohlool force-pushed the text_encoder_tpu3 branch 4 times, most recently from 252c34e to a449a5c Compare May 6, 2026 07:39
@mbohlool mbohlool force-pushed the text_encoder_tpu3 branch from a449a5c to d681d61 Compare May 6, 2026 07:44

```python
text_input_ids = jnp.array(text_inputs.input_ids)
prompt_attention_mask = jnp.array(text_inputs.attention_mask)

# Distribute the batch dimension across available TPUs to prevent Softmax OOM
```

Have we tested on Trillium? Since Trillium has less HBM, will this cause OOM issues? If so, we might have to consider TP (tensor parallelism).
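
For reference, one possible tensor-parallel direction (purely illustrative, not part of this PR): shard the encoder's large weight matrices along a "model" mesh axis so each chip holds only a slice, trading some communication for lower per-device HBM pressure:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 2-D mesh; assumes the device count is divisible by 2.
devices = np.array(jax.devices()).reshape(2, -1)
mesh = Mesh(devices, axis_names=("data", "model"))

# Example: shard a large (hidden, intermediate) projection weight along "model"
# so each chip stores only a column slice of the matrix.
w = jnp.zeros((3072, 24576), dtype=jnp.bfloat16)
w = jax.device_put(w, NamedSharding(mesh, P(None, "model")))
```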

```python
)
text_encoder.eval()

with default_env():
```

Should we have a config param so users can decide whether to run the text encoder on CPU vs. TPU? It might be useful when dealing with older chips with lower HBM.
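
A hypothetical shape such a flag could take (the name text_encoder_device and the two helper functions are made up for illustration; the real config schema may differ):

```python
# Hypothetical config flag; defaults to the TorchAX/TPU path introduced in this PR.
text_encoder_device = getattr(config, "text_encoder_device", "tpu")

if text_encoder_device == "cpu":
    # Fallback for low-HBM chips: run the eager PyTorch encoder on the host.
    prompt_embeds = encode_prompts_on_cpu(text_encoder, text_inputs)  # hypothetical helper
else:
    # Default: JAX-native TorchAX path with batch sharding across the TPU mesh.
    prompt_embeds = encode_prompts_torchax(text_encoder, text_inputs, mesh)  # hypothetical helper
```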

@github-actions

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

github-actions Bot left a comment

## 📋 Review Summary

This PR successfully integrates TorchAX for the LTX-2 pipeline's text encoder, bringing significant performance improvements and memory optimizations on TPU. The transition from eager PyTorch to JAX-native execution is well-implemented, and the additional sharding constraints for both the text encoder and VAE are effective strategies for preventing OOM crashes.

## 🔍 General Feedback

  • TorchAX Integration: The use of TorchaxGemma3TextEncoder and the manual batch sharding logic is a great addition for efficiency.
  • Memory Management: The conditional sharding and the disabling of sequential slicing in the VAE decoding loop correctly address HBM issues for larger batches (a rough sketch follows this list).
  • Distributed Performance: One critical observation is the explicit un-sharding of text encoder hidden states to a single device, which should be avoided to ensure optimal performance in multi-host environments.
  • Code Cleanliness: Small refactors to use getattr instead of broad try/except blocks will improve maintainability.
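
A rough sketch of the conditional VAE batch-sharding idea, assuming a "data" mesh axis, a 5-D video latent layout, and a batch-size threshold of 2 (all illustrative, not the pipeline's actual code):

```python
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

def shard_latents_for_vae_decode(latents, mesh, batch_size):
    """Illustrative only: shard large batches along the batch axis, replicate small ones."""
    if batch_size > 2:
        # Spread the (batch, channels, frames, height, width) latents across devices
        # so decoding a large batch fits within per-device HBM.
        sharding = NamedSharding(mesh, P("data", None, None, None, None))
    else:
        # Small batches are cheap enough to replicate on every device.
        sharding = NamedSharding(mesh, P())
    return jax.device_put(latents, sharding)
```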

Comment on lines +419 to 422
```python
except Exception:  # pylint: disable=broad-exception-caught
  replicate_vae = False
if replicate_vae:
  sharding = NamedSharding(mesh, P())
```

🟢 Using getattr is cleaner than a try/except block for checking optional config parameters.

Suggested change
```diff
-except Exception:  # pylint: disable=broad-exception-caught
-  replicate_vae = False
-if replicate_vae:
-  sharding = NamedSharding(mesh, P())
+replicate_vae = getattr(config, "replicate_vae", False)
```

Comment on lines +467 to 470
```python
except Exception:  # pylint: disable=broad-exception-caught
  replicate_vae = False
if replicate_vae:
  sharding = NamedSharding(mesh, P())
```

🟢 Same as above, getattr is preferred for optional configuration attributes.

Suggested change
```diff
-except Exception:  # pylint: disable=broad-exception-caught
-  replicate_vae = False
-if replicate_vae:
-  sharding = NamedSharding(mesh, P())
+replicate_vae = getattr(config, "replicate_vae", False)
```

Comment on lines +808 to 811
```python
state = jax.device_put(state, jax.devices()[0])
prompt_embeds_list.append(state.astype(jnp.bfloat16))

prompt_embeds = prompt_embeds_list
```

🟠 This call explicitly un-shards the hidden states to a single device (global device 0). This is detrimental in multi-host setups as it forces cross-host communication and can cause a memory bottleneck on the first TPU. Since the hidden states are already JAX arrays computed on the TPU mesh (and will be correctly resharded to the transformer's mesh at line 1367), this step should be removed to maintain efficient distributed execution.

Suggested change
```diff
-state = jax.device_put(state, jax.devices()[0])
-prompt_embeds_list.append(state.astype(jnp.bfloat16))
-prompt_embeds = prompt_embeds_list
+for state in text_encoder_hidden_states:
+  prompt_embeds_list.append(state.astype(jnp.bfloat16))
```

```diff
-for i in range(len(timesteps_jax)):
-  t = timesteps_jax[i]
+for _, t in enumerate(timesteps_jax):
```

🟢 The index _ is not used. You can simplify this to a direct iteration.

Suggested change
```diff
-for _, t in enumerate(timesteps_jax):
+for t in timesteps_jax:
```

```python
return inner_mask


transformers.masking_utils.sliding_window_overlay = _patched_sliding_window_overlay
```

🟡 Global monkeypatching of transformers.masking_utils can have unintended side effects if other models in the same process rely on the original behavior. While this is a necessary workaround for TorchAX + Gemma-3, consider documenting the sequence length assumption more explicitly or ensuring this patch doesn't break other potential future Gemma-based models in the same environment.
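
One way to limit the blast radius (a sketch, not something this PR does) is to scope the patch with a context manager that restores the original function once text encoding is done:

```python
import contextlib

import transformers.masking_utils as masking_utils

@contextlib.contextmanager
def sliding_window_overlay_patched(patched_fn):
    """Temporarily install the TorchAX-friendly overlay, restoring the original on exit."""
    original = masking_utils.sliding_window_overlay
    masking_utils.sliding_window_overlay = patched_fn
    try:
        yield
    finally:
        masking_utils.sliding_window_overlay = original

# Usage (hypothetical): only the text-encoding call sees the patched behavior.
# with sliding_window_overlay_patched(_patched_sliding_window_overlay):
#     prompt_embeds = run_text_encoder(...)
```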
